import matplotlib.pyplot as plt
import numpy as np

# The 4x4 Sudoku grid to plot, one inner tuple per row (top to bottom).
# A 0 stands for a cell in which no number is drawn.
sudoku = np.array((
    (0, 3, 0, 0),
    (0, 4, 0, 0),
    (0, 2, 1, 4),
    (4, 1, 3, 2),
))


def draw_sudoku(puzzle, save_path):
    """Render a grid of numbers as a Sudoku-style board and save it to disk.

    Each non-zero entry of ``puzzle`` is drawn centred in its cell; zero
    entries are left blank. Rows and columns are labelled 1..N in red, and
    the image is written to ``save_path`` at 300 dpi.

    Args:
        puzzle: A non-empty 2-D array-like of numbers (for example a 4x4
            NumPy array). A value of 0 marks an empty cell.
        save_path: Destination passed to ``Figure.savefig`` (a path or a
            file-like object).

    Raises:
        ValueError: If ``puzzle`` is not a non-empty 2-D grid.
    """
    # Draw the grid that was actually passed in, not a module-level global.
    grid = np.asarray(puzzle)
    if grid.ndim != 2 or grid.size == 0:
        raise ValueError(
            f"puzzle must be a non-empty 2-D grid, got shape {grid.shape}"
        )
    n_rows, n_cols = grid.shape

    # Setting up the plot with axes
    fig, ax = plt.subplots(figsize=(5, 5))
    try:
        # A uniform background only; the numbers are added as text below.
        ax.matshow(np.zeros_like(grid), cmap='Greys', alpha=0.3)

        # Adding the numbers to the grid (x is the column, y is the row)
        for (row, col), value in np.ndenumerate(grid):
            if value != 0:
                ax.text(col, row, str(value),
                        va='center', ha='center', fontsize=20)

        # Drawing the grid lines; cells are centred on integer coordinates,
        # so their borders sit at the half-integers.
        for y in range(n_rows + 1):
            ax.axhline(y - 0.5, lw=2, color='black', alpha=0.7)
        for x in range(n_cols + 1):
            ax.axvline(x - 0.5, lw=2, color='black', alpha=0.7)

        # Adding axis labels and 1-based indices
        ax.set_xticks(np.arange(n_cols))
        ax.set_yticks(np.arange(n_rows))
        ax.set_xticklabels(np.arange(1, n_cols + 1), color='red')
        ax.set_yticklabels(np.arange(1, n_rows + 1), color='red')

        # Labeling the axes
        ax.set_xlabel('X Axis', color='red')
        ax.set_ylabel('Y Axis', color='red')

        # Save this specific figure rather than whichever one is "current".
        fig.savefig(save_path, dpi=300)
    finally:
        # pyplot keeps every figure alive until it is closed; release it so
        # repeated calls do not accumulate open figures.
        plt.close(fig)
